import logging
import argparse
import os

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from data import Dataset
from network.Network import ResNet18ARM
from network.Network2 import ResNet18ARMMasking
from utils import image


def _str2bool(value):
    """Convert a command-line string to a real boolean.

    argparse's ``type=bool`` is a trap: any non-empty string — including
    the literal 'False' — is truthy, so ``--cuda False`` would silently
    enable CUDA. This converter accepts true/false, yes/no, 1/0
    (case-insensitive) and rejects everything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 'yes', '1'):
        return True
    if lowered in ('false', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')


def parse(argv=None):
    """Parse command-line options for FER training.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, preserving the original behavior;
            pass an explicit list for testing.

    Returns:
        argparse.Namespace with all training/configuration options.
    """
    parser = argparse.ArgumentParser(description='My Awesome FER')
    # _str2bool (not type=bool) so '--cuda False' actually disables CUDA.
    parser.add_argument('--cuda', default=True, type=_str2bool, help='enable cuda or not')

    parser.add_argument('--epochs', default=70, type=int, help='epochs for training')
    parser.add_argument('--batch_size', default=256, type=int, help='batch size for training')
    parser.add_argument('--test_batch_size', default=64, type=int, help='batch size of testing')
    parser.add_argument('--optimizer', default='adam', type=str, help='optimizer adam or sgd')
    parser.add_argument('--drop_rate', default=0.5, type=float, help='drop rate')

    parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate for sgd')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum for sgd')

    parser.add_argument('--rafdb_base', default='E:/MLPR/RAF-DB/basic', type=str,
                        help='path to RAF-DB basic directory')
    parser.add_argument('--sfew_base', default='E:/MLPR/SFEW', type=str,
                        help='path to SFEW base directory')

    parser.add_argument('--log_dir', default='./runs/', type=str,
                        help='path to log dir')
    parser.add_argument('--checkpoint_dir', default='./model/', type=str,
                        help='path to save checkpoints')

    parser.add_argument('--resume', default=False, type=_str2bool,
                        help='whether to resume from checkpoint')
    parser.add_argument('--resume_checkpoint', default='X:/Network/2/epoch8_acc0.76508.pth', type=str,
                        help='checkpoint to resume')

    parser.add_argument('--feed_huawei', default='E:/resnet', type=str,
                        help='feed local cache: Solve the problem that huawei cloud CANNOT access PyTorch')

    return parser.parse_args(argv)


def _train_epoch(model, dataloader, loss_fn, optimizer, writer, device, epoch):
    """Run one training epoch and return the epoch's training accuracy."""
    model.train()
    dataset_size = len(dataloader.dataset)
    batch_size = dataloader.batch_size
    correct = 0
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)

        # Model returns (logits, extra) — only the logits feed the loss.
        pred, _ = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicts = torch.argmax(pred, 1)
        correct += torch.eq(predicts, y).sum().item()
        if batch % 10 == 0:  # log every 10 batches
            # min() keeps the progress index honest on a short final batch.
            current_idx = min((batch + 1) * batch_size, dataset_size)
            logging.info(f"loss: {loss.item():>7f}  [{current_idx:>5d}/{dataset_size:>5d}]")
            writer.add_scalar('Loss/Train', loss.item(), epoch * dataset_size + current_idx)

    return correct / dataset_size


def _test_epoch(model, dataloader, loss_fn, device):
    """Evaluate the model; return (accuracy, mean per-sample loss)."""
    model.eval()
    dataset_size = len(dataloader.dataset)
    total_loss, correct = 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)

            pred, _ = model(X)

            # Weight by the actual batch size so a short final batch
            # doesn't skew the mean loss.
            total_loss += loss_fn(pred, y).item() * y.size(0)
            correct += torch.eq(torch.argmax(pred, 1), y).sum().item()

    return correct / dataset_size, total_loss / dataset_size


def main():
    """Entry point: build data/model/optimizer, run the train/test loop."""
    arg = parse()

    # Resolve the device once. Guards against --cuda on a CPU-only machine:
    # previously every tensor was moved with an unconditional .cuda().
    use_cuda = arg.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Prepare Data
    logging.info('[Pre] Preparing Data')
    train_dataloader, test_dataloader = Dataset.prepare_train(arg)

    # Build Network
    logging.info('[Pre] Building Network')
    model = ResNet18ARMMasking(pretrained=True,
                               num_classes=7,
                               drop_rate=arg.drop_rate,
                               feed_huawei=arg.feed_huawei)
    model.to(device)

    # Optimizer
    if arg.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
    elif arg.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), arg.learning_rate,
                                    momentum=arg.momentum, weight_decay=1e-4)
    else:
        raise ValueError('optimizer can only be adam or sgd')

    # Checkpoint resume
    if arg.resume:
        # map_location lets a GPU-trained checkpoint resume on CPU.
        checkpoint = torch.load(arg.resume_checkpoint, map_location=device)
        epoch_start = checkpoint['epoch']
        best_accuracy = checkpoint['best_acc']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        epoch_start = 0
        best_accuracy = 0.0

    # Loss function & Scheduler
    loss_fn = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    # Tensorboard
    writer = SummaryWriter(arg.log_dir)
    writer.add_graph(model, torch.randn(arg.batch_size, 3, 224, 224).to(device))

    # torch.save below fails if the directory does not exist yet.
    os.makedirs(arg.checkpoint_dir, exist_ok=True)

    # Train & Test Loop
    logging.info('[train] Begin')
    for epoch in range(epoch_start, arg.epochs):
        logging.info('[train] Epoch: %d', epoch + 1)

        train_accuracy = _train_epoch(model, train_dataloader, loss_fn,
                                      optimizer, writer, device, epoch)
        writer.add_scalar('Accuracy/Train', train_accuracy, epoch)
        logging.info(f'Train: Accuracy: {(100 * train_accuracy):>0.3f}%')
        scheduler.step()

        test_accuracy, test_loss = _test_epoch(model, test_dataloader, loss_fn, device)
        logging.info(f"Test: Accuracy: {(100 * test_accuracy):>0.3f}%, Avg loss: {test_loss:>8f}")
        writer.add_scalar('Accuracy/Test', test_accuracy, epoch)
        writer.add_scalar('Loss/Test', test_loss, epoch)
        writer.flush()

        # Save only checkpoints above 0.75 that also beat the running best.
        if test_accuracy > 0.75 and test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save({'epoch': epoch,
                        'best_acc': test_accuracy,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()},
                       os.path.join(arg.checkpoint_dir, f"epoch{epoch}_acc{test_accuracy:>0.5f}.pth"))

    logging.info('[train] Done')

    writer.close()


if __name__ == "__main__":
    # Emit verbose (DEBUG-level) logs to stderr when run as a script,
    # then hand off to the training entry point.
    logging.basicConfig(level=logging.DEBUG)
    main()
